"""
Interface for slcan compatible interfaces (win32/linux).

 ------------------------------------------------------------------
 Update Author : luojiaaoo
 Last change: 11.02.2024
 github: https://github.com/luojiaaoo
 ------------------------------------------------------------------
"""

import io
import logging
import time
from typing import Any, Optional, Tuple, List
from queue import Queue, Empty
import threading
from serial.tools import list_ports
from can import BusABC, CanProtocol, Message, typechecking

from ..exceptions import (
    CanInitializationError,
    CanInterfaceNotImplementedError,
    CanOperationError,
    error_check,
)

logger = logging.getLogger(__name__)

# pyserial is treated as an optional dependency: its absence is remembered as
# ``serial = None`` so that slcanBus.__init__ can raise
# CanInterfaceNotImplementedError instead of this module failing on import.
# NOTE(review): the unconditional ``from serial.tools import list_ports`` near the
# top of this file already raises ImportError when pyserial is missing, so this
# fallback looks unreachable as written - confirm.
try:
    import serial
except ImportError:
    logger.warning(
        "You won't be able to use the slcan can backend without " "the serial module installed!"
    )
    serial = None

# True only when the event helpers below could be imported.
HAS_EVENTS = False

try:
    # Private CPython modules; presumably only available on Windows - confirm.
    from _overlapped import CreateEvent
    from _winapi import WaitForSingleObject

    HAS_EVENTS = True
except ImportError:
    # Without these helpers the reader thread uses a portable timed wait between polls.
    WaitForSingleObject = None
    HAS_EVENTS = False


class slcanBus(BusABC):
    """
    slcan interface
    """

    # Supported bitrates (bit/s) and the slcan command that selects each of them.
    # NOTE(review): the ``S*`` and ``Y*`` commands share this one table, which both
    # set_bitrate() and set_data_bitrate() look up; presumably the ``Y*`` entries are
    # the CAN FD data-phase rates of the target firmware - confirm against its manual.
    _BITRATES = {
        10000: "S0",
        20000: "S1",
        50000: "S2",
        100000: "S3",
        125000: "S4",
        250000: "S5",
        500000: "S6",
        750000: "S7",
        1000000: "S8",
        83300: "S9",
        2000000: "Y2",
        5000000: "Y5",
    }

    # Default for the ``sleep_after_open`` constructor argument.
    _SLEEP_AFTER_SERIAL_OPEN = 2  # in seconds
    # Default for the ``poll_interval`` constructor argument, in seconds.
    _POLL_INTERVAL = 0.001

    # Bytes that end a line sent by the adapter; the readers treat both as terminators.
    _OK = b"\r"
    _ERROR = b"\a"

    # Appended to every command written to the adapter.
    LINE_TERMINATOR = b"\r"

    # DLC code found in a frame -> payload length in bytes (used for the FD frame types).
    DLC2BYTE_LEN = {
        0: 0,
        1: 1,
        2: 2,
        3: 3,
        4: 4,
        5: 5,
        6: 6,
        7: 7,
        8: 8,
        9: 12,
        10: 16,
        11: 20,
        12: 24,
        13: 32,
        14: 48,
        15: 64,
    }
    # Inverse of DLC2BYTE_LEN: payload length in bytes -> DLC code (used when sending FD frames).
    BYTE_LEN2DLC = {j: i for i, j in DLC2BYTE_LEN.items()}

    def __init__(
        self,
        channel: typechecking.ChannelStr,
        ttyBaudrate: int = 115200,
        bitrate: Optional[int] = None,
        fd: bool = False,
        data_bitrate: Optional[int] = None,
        btr: Optional[str] = None,
        poll_interval: float = _POLL_INTERVAL,
        receive_own_messages: bool = False,
        sleep_after_open: float = _SLEEP_AFTER_SERIAL_OPEN,
        rtscts: bool = False,
        listen_only: bool = False,
        timeout: float = 0.001,
        **kwargs: Any,
    ) -> None:
        """
        Open the serial port, configure the adapter and start the background reader.

        :param str channel:
            port of underlying serial or usb device (e.g. ``/dev/ttyUSB0``, ``COM8``, ...)
            Must not be empty. Can also end with ``@115200`` (or similarly) to specify the baudrate.
        :param int ttyBaudrate:
            baudrate of underlying serial or usb device (Ignored if set via the ``channel`` parameter)
        :param bitrate:
            Bitrate in bit/s
        :param bool fd:
            If CAN-FD frames should be supported.
        :param int data_bitrate:
            Which bitrate to use for data phase in CAN FD (only applied when ``fd`` is set).
        :param btr:
            BTR register value to set custom can speed
        :param receive_own_messages:
            See :class:`can.BusABC`.
        :param poll_interval:
            Poll interval in seconds when reading messages
        :param sleep_after_open:
            Time to wait in seconds after opening serial connection
        :param rtscts:
            turn hardware handshake (RTS/CTS) on and off
        :param listen_only:
            If True, open interface/channel in listen mode with ``L`` command.
            Otherwise, the (default) ``O`` command is still used. See ``open`` method.
        :param timeout:
            Timeout for the serial or usb device in seconds (default 0.001)

        :raise ValueError: if both ``bitrate`` and ``btr`` are set or the channel is invalid
        :raise CanInterfaceNotImplementedError: if the serial module is missing
        :raise CanInitializationError: if the underlying serial connection could not be
            established or the adapter could not be configured
        """
        self._listen_only = listen_only
        self.poll_interval = poll_interval
        self.receive_own_messages = receive_own_messages

        if serial is None:
            raise CanInterfaceNotImplementedError("The serial module is not installed")

        if not channel:  # if None or empty
            raise ValueError("Must specify a serial port.")
        # Argument validation happens before any hardware is touched so that the
        # documented ValueError reaches the caller and no port is left open.
        if bitrate is not None and btr is not None:
            raise ValueError("Bitrate and btr mutually exclusive.")
        if "@" in channel:
            # Only the last "@" separates the baudrate; earlier ones belong to the port/URL.
            channel, _, baudrate = channel.rpartition("@")
            ttyBaudrate = int(baudrate)

        with error_check(exception_type=CanInitializationError):
            self.serialPortOrig = serial.serial_for_url(
                channel,
                baudrate=ttyBaudrate,
                rtscts=rtscts,
                timeout=timeout,
            )

        self._buffer = bytearray()
        self._can_protocol = CanProtocol.CAN_FD if fd else CanProtocol.CAN_20

        time.sleep(sleep_after_open)

        try:
            with error_check(exception_type=CanInitializationError):
                if bitrate is not None:
                    self.set_bitrate(bitrate)
                if btr is not None:
                    self.set_bitrate_reg(btr)
                if fd and data_bitrate is not None:
                    self.set_data_bitrate(data_bitrate)
                self.open()
        except CanInitializationError:
            # Do not leak the serial port when the adapter cannot be configured.
            self.serialPortOrig.close()
            raise

        # Offset that turns time.perf_counter() readings into wall-clock timestamps.
        self._timestamp_offset = time.time() - time.perf_counter()
        self.queue_read = Queue()
        self.event_read = threading.Event()
        self._recv_event = CreateEvent(None, 0, 0, None) if HAS_EVENTS else None
        super().__init__(
            channel,
            ttyBaudrate=115200,
            bitrate=None,
            rtscts=False,
            **kwargs,
        )
        # Started last so that a failing base-class constructor leaves no stray thread
        # behind; daemon so that a forgotten shutdown() cannot keep the process alive.
        self._read_thread = threading.Thread(
            target=self._read_can_thread,
            args=(self.event_read,),
            daemon=True,
        )
        self._read_thread.start()

    def set_bitrate(self, bitrate: int) -> None:
        """
        Select the nominal (arbitration phase) bitrate and reopen the channel.

        Only the ``S*`` commands of the bitrate table are accepted here; the ``Y*``
        commands belong to :meth:`set_data_bitrate` and would not change the
        nominal bitrate.

        :param bitrate:
            Bitrate in bit/s

        :raise ValueError: if ``bitrate`` is not among the possible values
        """
        nominal_codes = {
            rate: code for rate, code in self._BITRATES.items() if code.startswith("S")
        }
        if bitrate not in nominal_codes:
            bitrates = ", ".join(str(k) for k in nominal_codes)
            raise ValueError(f"Invalid bitrate, choose one of {bitrates}.")

        # The channel is closed while the bitrate command is sent, then reopened.
        self.close()
        self._write(nominal_codes[bitrate])
        self.open()

    def set_data_bitrate(self, data_bitrate: int) -> None:
        """
        Select the CAN FD data phase bitrate and reopen the channel.

        Only the ``Y*`` commands of the bitrate table are accepted here. The ``S*``
        commands select the nominal bitrate, so sending one of them from this method
        would silently reconfigure the arbitration phase instead of the data phase.

        :param data_bitrate:
            Bitrate in bit/s

        :raise ValueError: if ``data_bitrate`` is not among the possible values
        """
        data_codes = {
            rate: code for rate, code in self._BITRATES.items() if code.startswith("Y")
        }
        if data_bitrate not in data_codes:
            data_bitrates = ", ".join(str(k) for k in data_codes)
            raise ValueError(f"Invalid bitrate, choose one of {data_bitrates}.")

        # The channel is closed while the bitrate command is sent, then reopened.
        self.close()
        self._write(data_codes[data_bitrate])
        self.open()

    def set_bitrate_reg(self, btr: str) -> None:
        """
        Program a custom CAN speed through the ``s`` command and reopen the channel.

        :param btr:
            BTR register value to set custom can speed
        """
        self.close()
        register_command = "s" + btr
        self._write(register_command)
        self.open()

    def _write(self, string: str) -> None:
        """Send one command to the adapter, followed by the line terminator, and flush it.

        :param string:
            command text without terminator
        """
        with error_check("Could not write to serial device"):
            port = self.serialPortOrig
            frame = string.encode() + self.LINE_TERMINATOR
            port.write(frame)
            port.flush()

    def flush(self) -> None:
        """Drop everything received so far: partial line, queued messages and port input."""
        del self._buffer[:]
        # A fresh queue replaces the old one instead of draining it.
        self.queue_read = Queue()
        with error_check("Could not flush"):
            port = self.serialPortOrig
            port.reset_input_buffer()

    def open(self) -> None:
        """Open the CAN channel: ``L`` in listen-only mode, ``O`` otherwise."""
        open_command = "L" if self._listen_only else "O"
        self._write(open_command)

    def close(self) -> None:
        """Send the ``C`` command to the adapter; the serial port itself stays open."""
        self._write("C")

    def _recv_internal(self, timeout: Optional[float]):
        """Take the next message from the queue filled by the reader thread.

        :param timeout:
            seconds to wait for a message, passed on to ``Queue.get``

        :return:
            tuple ``(message, False)``; ``message`` is :obj:`None` when nothing
            arrived in time. The second element is always ``False``.
        """
        try:
            received = self.queue_read.get(timeout=timeout)
        except Empty:
            received = None
        return received, False

    def _read_can_thread(self, event_read):
        """Body of the background reader: poll the adapter until asked to stop.

        Every message returned by :meth:`_read_can` is put on ``queue_read``.

        :param event_read:
            :class:`threading.Event` that ends the loop once it is set
        """
        while not event_read.is_set():
            try:
                msgs = self._read_can()
            except Exception:
                if event_read.is_set():
                    # The port was closed under us while stopping: not an error.
                    break
                raise
            for msg in msgs:
                self.queue_read.put(msg)
            if HAS_EVENTS:
                WaitForSingleObject(self._recv_event, int(self.poll_interval * 1000))
            else:
                # Unlike time.sleep() this returns as soon as the stop event is set.
                event_read.wait(self.poll_interval)

    def _read_can(self) -> List[Message]:
        """Read the bytes pending on the serial port and decode every completed line.

        Bytes of an unfinished line stay in ``self._buffer`` until a later call
        completes them. Lines that are not CAN frames (command replies, bare
        terminators) are ignored. Malformed frames are logged and skipped so that
        one corrupted line cannot stop reception.

        :return:
            the messages decoded from the lines completed by this read (possibly empty)
        """
        with error_check("Could not read from serial device"):
            # Accessing `serialPortOrig.in_waiting` too often reduces the performance,
            # so it is queried once and everything pending is fetched in a single read.
            pending = self.serialPortOrig.in_waiting
            chunk = self.serialPortOrig.read(size=max(1, pending))
        if not chunk:
            return []

        self._buffer.extend(chunk)
        # Both terminators end a line; whatever follows the last terminator is an
        # incomplete line that is kept for the next call.
        *lines, remainder = self._buffer.replace(self._ERROR, self._OK).split(self._OK)
        self._buffer[:] = remainder

        msgs = []
        for raw_line in lines:
            if not raw_line:
                continue
            try:
                msg = self._parse_frame(raw_line.decode())
            except (ValueError, IndexError) as error:
                logger.warning("Discarding malformed slcan line %r: %s", bytes(raw_line), error)
                continue
            if msg is not None:
                msgs.append(msg)
        return msgs

    def _parse_frame(self, line: str) -> Optional[Message]:
        """Decode one received line (without terminator) into a message.

        :param line:
            text of the line, starting with the frame type character

        :return:
            the decoded message, or :obj:`None` if the line is not a CAN frame

        :raise ValueError: if a field is not valid or the payload is truncated
        :raise IndexError: if the line ends before the DLC field
        """
        kind = line[:1]
        if kind in ("t", "r", "d", "b"):
            id_digits, extended = 3, False
        elif kind in ("T", "x", "R", "D", "B"):
            id_digits, extended = 8, True
        else:
            return None
        remote = kind in ("r", "R")
        fd = kind in ("d", "D", "b", "B")
        brs = kind in ("b", "B")

        canId = int(line[1 : 1 + id_digits], 16)
        dlc_char = line[1 + id_digits]
        if fd:
            # send() writes the FD DLC as one hex digit, so it is parsed as hex here too.
            length = self.DLC2BYTE_LEN[int(dlc_char, 16)]
        else:
            length = int(dlc_char)

        data = None
        if not remote:
            # Remote frames carry no payload; anything after their DLC is ignored.
            start = id_digits + 2
            payload = line[start : start + 2 * length]
            if len(payload) != 2 * length:
                raise ValueError(f"payload is shorter than the announced {length} bytes")
            data = bytearray.fromhex(payload)

        return Message(
            is_fd=fd,
            bitrate_switch=brs,
            arbitration_id=canId,
            is_extended_id=extended,
            timestamp=self._timestamp_offset + time.perf_counter(),  # Better than nothing...
            is_remote_frame=remote,
            dlc=length,
            data=data,
        )

    def send(self, msg: Message, timeout: Optional[float] = None) -> None:
        """Encode ``msg`` as an slcan transmit command and write it to the adapter.

        With ``receive_own_messages`` enabled, a copy of the message (marked as not
        received and stamped with the current time) is queued for reception once the
        write has succeeded. The caller's message object is left untouched, so
        sending the same object again cannot alter echoes that are already queued.

        :param msg:
            message to transmit
        :param timeout:
            write timeout in seconds applied to the serial port (:obj:`None` blocks)

        :raise KeyError: if a CAN FD message has a length without a DLC code
        """
        if timeout != self.serialPortOrig.write_timeout:
            self.serialPortOrig.write_timeout = timeout

        has_payload = True
        if msg.is_fd:
            kind = "b" if msg.bitrate_switch else "d"
            dlc_field = f"{self.BYTE_LEN2DLC[msg.dlc]:x}"
        elif msg.is_remote_frame:
            kind = "r"
            dlc_field = f"{msg.dlc:d}"
            # Remote frames request data, they do not carry any.
            has_payload = False
        else:
            kind = "t"
            dlc_field = f"{msg.dlc:d}"

        # Upper-case frame type = extended (8 hex digit) identifier.
        if msg.is_extended_id:
            sendStr = f"{kind.upper()}{msg.arbitration_id:08X}{dlc_field}"
        else:
            sendStr = f"{kind}{msg.arbitration_id:03X}{dlc_field}"
        if has_payload:
            sendStr += msg.data.hex().upper()

        self._write(sendStr)

        if self.receive_own_messages:
            import copy

            echo = copy.copy(msg)
            echo.is_rx = False
            echo.timestamp = self._timestamp_offset + time.perf_counter()  # Better than nothing...
            self.queue_read.put(echo)

    def shutdown(self) -> None:
        """Stop the reader thread, close the CAN channel and release the serial port.

        The serial port is closed even if sending the close command fails.
        """
        super().shutdown()
        self.event_read.set()
        try:
            self.close()
        finally:
            with error_check("Could not close serial socket"):
                self.serialPortOrig.close()

    def fileno(self) -> int:
        """Return the file descriptor of the underlying serial port.

        :raise NotImplementedError: if the port does not support ``fileno``
        :raise CanOperationError: if fetching the descriptor fails for any other reason
        """
        try:
            descriptor = self.serialPortOrig.fileno()
        except io.UnsupportedOperation:
            raise NotImplementedError(
                "fileno is not implemented using current CAN bus on this platform"
            ) from None
        except Exception as error:
            raise CanOperationError("Cannot fetch fileno") from error
        return descriptor

    def _read(self, timeout: Optional[float]) -> Optional[str]:
        """Read one complete line (terminator included) from the adapter.

        NOTE(review): the background reader thread consumes the same port and the
        same buffer, so a reply may be taken by that thread before this method
        sees it - confirm.

        :param timeout:
            seconds to wait for a complete line, handed to ``serial.Timeout``

        :return:
            the decoded line, or :obj:`None` if the timeout expired first
        """
        deadline = serial.Timeout(timeout)
        terminators = (self._ERROR, self._OK)

        with error_check("Could not read from serial device"):
            while True:
                # `in_waiting` is expensive to query, so it is read once per pass.
                pending = max(1, self.serialPortOrig.in_waiting)
                for _ in range(pending):
                    byte = self.serialPortOrig.read(size=1)
                    if not byte:
                        break
                    self._buffer.extend(byte)
                    if byte in terminators:
                        line = self._buffer.decode()
                        self._buffer.clear()
                        return line

                if deadline.expired():
                    return None

    def get_version(self, timeout: Optional[float]) -> Tuple[Optional[int], Optional[int]]:
        """Query the hardware and software version with the ``V`` command.

        :param timeout:
            seconds to wait for version or None to wait indefinitely

        :returns: tuple (hw_version, sw_version)
            WHERE
            int hw_version is the hardware version or None on timeout
            int sw_version is the software version or None on timeout
        """
        cmd = "V"
        self._write(cmd)
        reply = self._read(timeout)

        # A valid reply is the command letter, two digit pairs and the terminator.
        if reply and reply[0] == cmd and len(reply) == 6:
            return int(reply[1:3]), int(reply[3:5])
        return None, None

    def get_serial_number(self, timeout: Optional[float]) -> Optional[str]:
        """Query the serial number with the ``N`` command.

        :param timeout:
            seconds to wait for serial number or :obj:`None` to wait indefinitely

        :return:
            :obj:`None` on timeout or a :class:`str` object.
        """
        cmd = "N"
        self._write(cmd)
        reply = self._read(timeout)

        # A valid reply is the command letter, four characters and the terminator.
        if reply and reply[0] == cmd and len(reply) == 6:
            return reply[1:-1]
        return None

    @staticmethod
    def _detect_available_configs():
        """
        List every serial port as a possible slcan channel.

        The ports are not probed, so an entry does not prove that an slcan
        adapter is attached to it.
        """
        return [
            {
                "interface": "slcan",
                "channel": port_info.device,
                "name": port_info.description,
            }
            for port_info in list_ports.comports()
        ]